{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Py89P9hGzVjR"
      },
      "outputs": [],
      "source": [
        "!pip install transformers"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install datasets"
      ],
      "metadata": {
        "id": "U2L3s1E1zcNj"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install evaluate"
      ],
      "metadata": {
        "id": "dhvsP0b-zdsk"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install openai"
      ],
      "metadata": {
        "id": "Tjz_YAHmzfWh"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "GPT-4 ranking. The output of MindMap only choose the summary result (Output1)"
      ],
      "metadata": {
        "id": "YYIEH3uEzoSi"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import openai\n",
        "import csv\n",
        "from time import sleep\n",
        "\n",
        "input_file = 'output.csv'\n",
        "output_file = 'output_ranking.csv'\n",
        "\n",
        "openai.api_key = \"YOUR_OPENAI_KEY\"\n",
        "\n",
        "def prompt_ranking(reference_text,output1_text,output2_text,output3_text,output4_text,output5_text,output6_text):\n",
        "  prompt_template = \"\"\"\n",
        "  Reference: {reference}\n",
        "\n",
        "  output1: {output1}\n",
        "\n",
        "  output2: {output2}\n",
        "\n",
        "  output3: {output3}\n",
        "\n",
        "  output4: {output4}\n",
        "\n",
        "  output5: {output5}\n",
        "\n",
        "  output6: {output6}\n",
        "\n",
        "  According to the facts of the disease and the drug and test recommendations in reference output, order the fact match of the output from highest to lowest.\n",
        "\n",
        "  The output followed the format bellow：\n",
        "\n",
        "  1. Output1\n",
        "  2. Output5\n",
        "  3. Output2\n",
        "  4. Output3\n",
        "  5. Output4\n",
        "  6. Output6\n",
        "  \"\"\"\n",
        "\n",
        "  prompt = prompt_template.format(reference=reference_text, output1=output1_text, output2=output2_text, output3=output3_text, output4=output4_text, output5=output5_text, output6=output6_text)\n",
        "  response = openai.ChatCompletion.create(\n",
        "      model=\"gpt-4\",\n",
        "      messages=[\n",
        "          {\"role\": \"system\", \"content\": \"You are a excellent.\"},\n",
        "          {\"role\": \"user\", \"content\": prompt}\n",
        "      ]\n",
        "  )\n",
        "\n",
        "  generated_text = response.choices[0].message.content\n",
        "  return generated_text\n",
        "\n",
        "with open(input_file,'r',newline=\"\") as f_input, open(output_file, 'a+', newline='') as f_output:\n",
        "  reader = csv.reader(f_input)\n",
        "  writer = csv.writer(f_output)\n",
        "\n",
        "  header = next(reader)\n",
        "  header.append(\"gpt4_ranking\")\n",
        "  writer.writerow(header)\n",
        "\n",
        "  for row in reader:\n",
        "    reference_text = [row[1].strip(\"\\n\")]\n",
        "    output1_text = [row[2].strip(\"\\n\")]\n",
        "    output2_text = [row[5].strip(\"\\n\")]\n",
        "    output3_text = [row[6].strip(\"\\n\")]\n",
        "    output4_text = [row[7].strip(\"\\n\")]\n",
        "    output5_text = [row[8].strip(\"\\n\")]\n",
        "    output6_text = [row[9].strip(\"\\n\")]\n",
        "\n",
        "    flag = 0\n",
        "    while flag == 0:\n",
        "      try:\n",
        "        str1 = prompt_ranking(reference_text,output1_text,output2_text,output3_text,output4_text,output5_text,output6_text)\n",
        "        flag = 1\n",
        "      except:\n",
        "        sleep(40)\n",
        "        str1 = prompt_ranking(reference_text,output1_text,output2_text,output3_text,output4_text,output5_text,output6_text)\n",
        "\n",
        "    print(str1)\n",
        "    print('\\n')\n",
        "\n",
        "    row.append(str1)\n",
        "    writer.writerow(row)\n",
        "\n",
        "\n"
      ],
      "metadata": {
        "id": "ysdUHGg6zql6"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import re\n",
        "import csv\n",
        "\n",
        "input_file = 'output_ranking.csv'\n",
        "output_file = 'output_ranking_compute.csv'\n",
        "\n",
        "with open(input_file,'r',newline=\"\") as f_input, open(output_file, 'a+', newline='') as f_output:\n",
        "  reader = csv.reader(f_input)\n",
        "  writer = csv.writer(f_output)\n",
        "\n",
        "  header = next(reader)\n",
        "  header.extend([\"output1_ranking\",\"output2_ranking\",\"output3_ranking\",\"output4_ranking\",\"output5_ranking\",\"output6_ranking\"])\n",
        "  writer.writerow(header)\n",
        "\n",
        "  for row in reader:\n",
        "    output_text = row[10]\n",
        "    print(output_text)\n",
        "    matches = re.findall(r'(?m)^\\d+\\. (\\w+):', output_text)\n",
        "    # print(\"matches1\",matches)\n",
        "    if matches == []:\n",
        "      matches = re.findall(r'(?m)^\\d+\\. (\\w+)', output_text)\n",
        "    # print(\"matches2\",matches)\n",
        "    if matches == []:\n",
        "      matches = re.findall(r'\\d+\\.\\s+(\\w+)', output_text)\n",
        "    else:\n",
        "      if matches[0] == \"Output\":\n",
        "        matches = re.findall(r'(?m)^\\d+\\.\\s*(Output \\d+):', output_text)\n",
        "        # print('matches3',matches)\n",
        "        if matches == []:\n",
        "          matches = re.findall(r'\\d+\\.\\s+(Output\\s*\\d+)', output_text)\n",
        "\n",
        "    rankings = {}\n",
        "    for i, match in enumerate(matches):\n",
        "      rank = i + 1\n",
        "      output = match.replace(\" \",\"\")\n",
        "      print(output)\n",
        "      rankings[output] = rank\n",
        "\n",
        "    output_dic = {}\n",
        "\n",
        "    for output, rank in sorted(rankings.items(), key=lambda x: x[1]):\n",
        "      print(f\"{output}: {rank}\")\n",
        "      output_dic[output]=rank\n",
        "\n",
        "\n",
        "\n",
        "    print(output_dic['Output1'])\n",
        "    print('\\n')\n",
        "\n",
        "    row.extend([output_dic['Output1'],output_dic['Output2'],output_dic['Output3'],output_dic['Output4'],output_dic['Output5'],output_dic['Output6']])\n",
        "    writer.writerow(row)\n"
      ],
      "metadata": {
        "id": "9ndIJv2Ozt87"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import pandas as pd\n",
        "\n",
        "def calculate_column_average(file_path, column_name):\n",
        "    df = pd.read_csv(file_path)\n",
        "    average = df[column_name].mean()\n",
        "    return average\n",
        "\n",
        "column =[\"output1_ranking\",\"output2_ranking\",\"output3_ranking\",\"output4_ranking\",\"output5_ranking\",\"output6_ranking\"]\n",
        "\n",
        "file_path = './output_ranking_compute.csv'\n",
        "for column_name in column:\n",
        "  average = calculate_column_average(file_path, column_name)\n",
        "  print(f\"The average of column {column_name} is: {average}\")"
      ],
      "metadata": {
        "id": "pGngE-ZKzvc7"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}